ENH: jax_autojit #284

Open: wants to merge 3 commits into base: main

@crusaderky crusaderky commented Apr 25, 2025

Supercharge xpx.testing.lazy_xp_function.

This is a fairly high level change; @rgommers @pearu @jakevdp I could use your feedback.

Matching SciPy PR: scipy/scipy#22909

Automatic static arguments

scipy has seen a proliferation of lazy_xp_function(func, static_argnames=(...)) in the initial section of many test modules. This information is only useful for JAX and is quite verbose.
With scipy/scipy#22686, this information, which is specific to both JAX and unit tests, moves into the implementation modules.

With this PR, lazy_xp_function no longer accepts parameters static_argnames and static_argnums. Instead, all arguments that are not JAX arrays are automatically treated as static. Note that this behaviour is generally desirable in unit testing but a bad idea in production. Consider:

def f(x: Array, y: float, plus: bool) -> Array:
    return x + y if plus else x - y

j1 = jax.jit(f, static_argnames="plus")
j2 = jax_autojit(f)

jax_autojit is a new internal function of array-api-extra.
In the above example, j2 needs far less setup to be tested effectively. On the flip side, it is re-traced for every distinct value of y, which likely makes it unfit for production.

To clarify: jax_autojit is applied and removed on the fly within tests with the xp fixture and it is never used outside of unit testing.
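
The re-tracing cost described above can be illustrated without JAX. The following is only a toy model, not the PR's implementation: `FakeArray` and `toy_autojit` are hypothetical names, and the "JIT cache" is a plain set of keys in which arrays contribute their length and every other argument contributes its value, much like a static argument would.

```python
from typing import Any, Callable


class FakeArray:
    """Hypothetical stand-in for a JAX array; only its length is 'traced'."""

    def __init__(self, values: list[float]) -> None:
        self.values = list(values)


def toy_autojit(func: Callable[..., Any]) -> Callable[..., Any]:
    """Toy model of a JIT cache where every non-array argument is static."""
    seen_keys: set[tuple[Any, ...]] = set()

    def wrapper(*args: Any) -> Any:
        # Arrays contribute only their length to the key; everything else
        # contributes its value, much like a static argument would.
        key = tuple(
            ("array", len(a.values)) if isinstance(a, FakeArray) else ("static", a)
            for a in args
        )
        if key not in seen_keys:
            seen_keys.add(key)
            wrapper.n_traces += 1  # a real JIT would trace and compile here
        return func(*args)

    wrapper.n_traces = 0
    return wrapper


def f(x: FakeArray, y: float, plus: bool) -> FakeArray:
    return FakeArray([v + y if plus else v - y for v in x.values])


j2 = toy_autojit(f)
x = FakeArray([1.0, 2.0])
for y in (0.5, 1.5, 0.5):
    j2(x, y, True)
print(j2.n_traces)  # 2: one trace per distinct value of y
```

In this model a new value of `y` always costs a trace, which is tolerable in a test suite with a handful of inputs and is exactly the behaviour that would hurt in production.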

Wrapped inputs and output

There are a few cases of scipy functions returning bespoke containers with arrays inside instead of simple tuples, namedtuples, lists, or dicts of arrays.

Two such examples are

In main, these can't be tested with lazy_xp_function, and as of scipy/scipy#22686 the issue also impacts documentation.

This PR lifts this restriction and allows completely arbitrary objects as parameters and as return values of the functions. If these objects internally contain JAX arrays, lazy_xp_function will now automatically extract them, pass them through the JIT, and reassemble everything for the test function to observe the result.

The rationale is that, in real life, users are unlikely to wrap the scipy functions with jax.jit directly; instead they are more likely to consume their outputs in their own functions and then wrap those with jax.jit:

@jax.jit
def my_user_function(x):
    y = scipy.stats.ttest_ind(x)
    return my_user_consume(y)  # returns array or tuple of arrays
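
The extract-and-reassemble step for opaque containers can be sketched with the standard library's `pickle` hooks `persistent_id` and `persistent_load`. This is a minimal sketch under assumptions, not the code in this PR: `FakeArray` stands in for a JAX array, `flatten` and `unflatten` are hypothetical names, and `types.SimpleNamespace` plays the role of a bespoke result container.

```python
import io
import pickle
from types import SimpleNamespace
from typing import Any


class FakeArray:
    """Hypothetical stand-in for a JAX array."""

    def __init__(self, values: list[float]) -> None:
        self.values = list(values)


def flatten(obj: Any, cls: type) -> tuple[list[Any], bytes]:
    """Pickle obj, setting aside every instance of cls instead of pickling it."""
    found: list[Any] = []

    class Pickler(pickle.Pickler):
        def persistent_id(self, o: Any) -> Any:
            if isinstance(o, cls):
                found.append(o)
                return len(found) - 1  # position of o in `found`
            return None  # pickle everything else normally

    buffer = io.BytesIO()
    Pickler(buffer).dump(obj)
    return found, buffer.getvalue()


def unflatten(instances: list[Any], payload: bytes) -> Any:
    """Rebuild the object, substituting instances for the ones set aside."""

    class Unpickler(pickle.Unpickler):
        def persistent_load(self, pid: Any) -> Any:
            return instances[pid]

    return Unpickler(io.BytesIO(payload)).load()


result = SimpleNamespace(
    statistic=FakeArray([1.0, 2.0]), pvalue=FakeArray([3.0]), method="exact"
)
arrays, payload = flatten(result, FakeArray)
# Stand-in for "pass the arrays through the JIT": here they are just doubled.
doubled = [FakeArray([2 * v for v in a.values]) for a in arrays]
rebuilt = unflatten(doubled, payload)
print(rebuilt.statistic.values, rebuilt.pvalue.values, rebuilt.method)
```

The container itself never has to know about the array type; anything that pickle can traverse is handled the same way.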

Static non-hashable arguments

Non-hashable objects can now be static.
Consider:

def f(x: Array, verbs: Sequence[str]) -> Array:
    ...
    
lazy_xp_function(f, static_argnames="verbs")

def test_f(xp):
    x = ...
    y = f(x, ("foo", "bar", "baz"))  # OK
    y = f(x, ["foo", "bar", "baz"])  # Failed before this PR: a list is not hashable

This PR fixes it; all objects in the input and output of the function need only be hashable or pickleable.
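
One way to see why "hashable or pickleable" is enough is to turn any pickleable object into a hashable cache key. The helper below is an assumption for illustration (`as_cache_key` is a hypothetical name and not necessarily how the PR does it): hashable objects are used as-is, everything else is replaced by its pickled bytes.

```python
import pickle
from collections.abc import Hashable
from typing import Any


def as_cache_key(obj: Any) -> Hashable:
    """Return obj itself if hashable, otherwise a key built from its pickle."""
    try:
        hash(obj)
    except TypeError:
        # Not hashable (e.g. a list): fall back to the pickled bytes.
        return ("pickled", pickle.dumps(obj))
    return obj


cache: dict[Hashable, str] = {}
cache[as_cache_key(("foo", "bar", "baz"))] = "tuple key"
cache[as_cache_key(["foo", "bar", "baz"])] = "list key"
print(len(cache))  # 2
```

Pickled bytes are not a canonical representation in general (two equal objects can pickle differently), so in this sketch the worst case is a spurious cache miss rather than a wrong result.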

Dask materialization raises inside a container

This also fixes a Dask-specific bug where the graph materialization raises:

@pytest.mark.skip_xp_backends("jax.numpy", reason="raise inside jax.pure_callback")
def test1(xp):
    with pytest.raises(Exception):
        f(x)

The above works when f returns plain Dask arrays, or tuples or lists thereof: thanks to the lazy_xp_function machinery, the exception is raised even if the graph would otherwise be discarded and never computed. However, it used to fail when the return value was an opaque container with Dask arrays inside. This PR fixes it.

"""

@override
def persistent_id(self, obj: object) -> Literal[0, 1, None]: # pyright: ignore[reportIncompatibleMethodOverride] # numpydoc ignore=GL08
Contributor Author

To say that I am aggrieved by pyright and numpydoc would be a gentle understatement at this point.

Comment on lines +509 to +512
"""
Register upon first use instead of at import time, to avoid
globally importing JAX.
"""
Contributor Author

@crusaderky crusaderky Apr 25, 2025

Aside note on design: Dask avoids this exact problem by not requiring any decorator and instead duck-type checking for uniquely named dunder methods called __dask_<...>__

Contributor Author

Note: everything in this module is a private helper.

Comment on lines +42 to +48
class Deprecated(enum.Enum):
    """Unique type for deprecated parameters."""

    DEPRECATED = 1


DEPRECATED = Deprecated.DEPRECATED
Contributor Author

I couldn't find anything to help here in the standard library, and I find that odd. Did I miss anything? I was expecting something like @functools.deprecated_args("static_argnames", "static_argnums").

There is https://docs.python.org/3/library/warnings.html#warnings.deprecated that could be used in conjunction with typing.overload.

Contributor Author

Good pointer; however it requires Python >=3.13 so I can't use it 😞
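
For readers unfamiliar with the sentinel pattern discussed in this thread, here is a minimal sketch of how such an enum can flag deprecated keyword arguments on older Python versions. The `Deprecated` enum mirrors the snippet under review; `lazy_xp_function_sketch` and its warning text are hypothetical and only illustrate the idea.

```python
import enum
import warnings
from typing import Any, Callable


class Deprecated(enum.Enum):
    """Unique type for deprecated parameters."""

    DEPRECATED = 1


DEPRECATED = Deprecated.DEPRECATED


def lazy_xp_function_sketch(
    func: Callable[..., Any],
    *,
    static_argnums: Any = DEPRECATED,
    static_argnames: Any = DEPRECATED,
) -> Callable[..., Any]:
    """Hypothetical wrapper: warn if a deprecated parameter is passed."""
    if static_argnums is not DEPRECATED or static_argnames is not DEPRECATED:
        warnings.warn(
            "static_argnums and static_argnames are deprecated and ignored",
            DeprecationWarning,
            stacklevel=2,
        )
    return func


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    lazy_xp_function_sketch(len)  # no warning
    lazy_xp_function_sketch(len, static_argnames="verbs")  # warns
print(len(caught))  # 1
```

A dedicated sentinel, unlike `None`, distinguishes "not passed" from any value a caller could legitimately supply, and an enum member gives type checkers a precise `Literal` type to work with.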

@pearu pearu left a comment

I have a minor nit and a note, otherwise it looks good to me as it reduces the maintenance burden of adding/updating static_argnames/argnums arguments.

Btw, I enjoyed reading the pickle_flatten and pickle_unflatten implementations, as it felt like less is more.

Thanks, @crusaderky!

jakevdp commented Apr 27, 2025

For what it's worth, one of the main reasons we haven't implemented something like this in JAX is that it tends to negatively impact dispatch times. The output of jax.jit is a C++-level callable that dispatches directly to the compiled kernel after the initial call. This matters because putting a layer of Python logic in the dispatch path, to inspect all the arguments and decide whether to treat them as static (and possibly recompile based on this information), severely impacts dispatch times.

With that background, I'd like a bit of clarification here: is this intended for use only in testing paths, or do you imagine this will be used within the dispatch path for libraries like scipy that implement the Array API?

@lucascolley
Member

I think this section of the top post answers your question, Jake:

To clarify: jax_autojit is applied and removed on the fly within tests with the xp fixture and it is never used outside of unit testing.

crusaderky and others added 2 commits April 28, 2025 09:50
@crusaderky
Contributor Author

scipy/scipy#22909 does not cause any issues to crop up.
This PR is ready for final review.

@crusaderky crusaderky marked this pull request as ready for review April 29, 2025 10:51

4 participants